{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 25,
   "source": [
    "# Copyright (c) Facebook, Inc. and its affiliates.\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#    http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Bert Pipeline : PyTorch BERT News Classfication\n",
    "\n",
    "This notebook shows PyTorch BERT end-to-end news classification example using Kubeflow Pipelines.\n",
    "\n",
    "\n",
    "An example notebook that demonstrates how to:\n",
    "\n",
    "* Get different tasks needed for the pipeline\n",
    "* Create a Kubeflow pipeline\n",
    "* Include Pytorch KFP components to preprocess, train, visualize and deploy the model in the pipeline\n",
    "* Submit a job for execution\n",
    "* Query(prediction and explain) the final deployed model\n",
    "* Interpretation of the model using the Captum Insights\n"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "source": [
    "! pip uninstall -y kfp\n",
    "! pip install --no-cache-dir kfp"
   ],
   "outputs": [],
   "metadata": {
    "tags": []
   }
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "source": [
    "import kfp\n",
    "import json\n",
    "import os\n",
    "from kfp.onprem import use_k8s_secret\n",
    "from kfp import components\n",
    "from kfp.components import load_component_from_file, load_component_from_url\n",
    "from kfp import dsl\n",
    "from kfp import compiler\n",
    "\n",
    "kfp.__version__"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "'1.6.3'"
      ]
     },
     "metadata": {},
     "execution_count": 2
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Enter your gateway and the cookie\n",
    "[Use this extension on chrome to get token]( https://chrome.google.com/webstore/detail/editthiscookie/fngmhnnpilhplaeedifhccceomclgfbg?hl=en)\n",
    "\n",
    "![image.png](./image.png)"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Update values for the ingress gateway and auth session"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "source": [
    "INGRESS_GATEWAY='http://istio-ingressgateway.istio-system.svc.cluster.local'\n",
    "AUTH=\"<enter your token here>\"\n",
    "NAMESPACE=\"kubeflow-user-example-com\"\n",
    "COOKIE=\"authservice_session=\"+AUTH\n",
    "EXPERIMENT=\"Default\""
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Set Log bucket and Tensorboard Image"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "source": [
    "MINIO_ENDPOINT=\"http://minio-service.kubeflow:9000\"\n",
    "LOG_BUCKET=\"mlpipeline\"\n",
    "TENSORBOARD_IMAGE=\"public.ecr.aws/pytorch-samples/tboard:latest\""
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "source": [
    "client = kfp.Client(host=INGRESS_GATEWAY+\"/pipeline\", cookies=COOKIE)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "source": [
    "client.create_experiment(EXPERIMENT)\n",
    "experiments = client.list_experiments(namespace=NAMESPACE)\n",
    "my_experiment = experiments.experiments[0]\n",
    "my_experiment"
   ],
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ],
      "text/html": [
       "<a href=\"http://istio-ingressgateway.istio-system.svc.cluster.local/pipeline/#/experiments/details/aac96a63-616e-4d88-9334-6ca8df2bb956\" target=\"_blank\" >Experiment details</a>."
      ]
     },
     "metadata": {}
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "{'created_at': datetime.datetime(2021, 4, 22, 8, 44, 39, tzinfo=tzlocal()),\n",
       " 'description': None,\n",
       " 'id': 'aac96a63-616e-4d88-9334-6ca8df2bb956',\n",
       " 'name': 'Default',\n",
       " 'resource_references': [{'key': {'id': 'kubeflow-user-example-com',\n",
       "                                  'type': 'NAMESPACE'},\n",
       "                          'name': None,\n",
       "                          'relationship': 'OWNER'}],\n",
       " 'storage_state': 'STORAGESTATE_AVAILABLE'}"
      ]
     },
     "metadata": {},
     "execution_count": 6
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Set Inference parameters"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "source": [
    "DEPLOY_NAME=\"bertserve\"\n",
    "MODEL_NAME=\"bert\""
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "source": [
    "prepare_tensorboard_op = load_component_from_file(\"common/tensorboard/component.yaml\")\n",
    "prep_op = components.load_component_from_file(\n",
    "    \"bert/yaml/pre_process/component.yaml\"\n",
    ")\n",
    "train_op = components.load_component_from_file(\n",
    "    \"bert/yaml/train/component.yaml\"\n",
    ")\n",
    "deploy_op = load_component_from_file(\"common/deploy/component.yaml\")\n",
    "minio_op = components.load_component_from_file(\n",
    "    \"common/minio/component.yaml\"\n",
    ")"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Define pipeline"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "source": [
    "@dsl.pipeline(name=\"Training pipeline\", description=\"Sample training job test\")\n",
    "def pytorch_bert( # pylint: disable=too-many-arguments\n",
    "    minio_endpoint=MINIO_ENDPOINT,\n",
    "    log_bucket=LOG_BUCKET,\n",
    "    log_dir=f\"tensorboard/logs/{dsl.RUN_ID_PLACEHOLDER}\",\n",
    "    mar_path=f\"mar/{dsl.RUN_ID_PLACEHOLDER}/model-store\",\n",
    "    config_prop_path=f\"mar/{dsl.RUN_ID_PLACEHOLDER}/config\",\n",
    "    model_uri=f\"s3://mlpipeline/mar/{dsl.RUN_ID_PLACEHOLDER}\",\n",
    "    tf_image=TENSORBOARD_IMAGE,\n",
    "    deploy=DEPLOY_NAME,\n",
    "    namespace=NAMESPACE,\n",
    "    confusion_matrix_log_dir=f\"confusion_matrix/{dsl.RUN_ID_PLACEHOLDER}/\",\n",
    "    num_samples=1000,\n",
    "    max_epochs=1\n",
    "):\n",
    "    \"\"\"Thid method defines the pipeline tasks and operations\"\"\"\n",
    "    prepare_tb_task = prepare_tensorboard_op(\n",
    "        log_dir_uri=f\"s3://{log_bucket}/{log_dir}\",\n",
    "        image=tf_image,\n",
    "        pod_template_spec=json.dumps({\n",
    "            \"spec\": {\n",
    "                \"containers\": [{\n",
    "                    \"env\": [\n",
    "                        {\n",
    "                            \"name\": \"AWS_ACCESS_KEY_ID\",\n",
    "                            \"valueFrom\": {\n",
    "                                \"secretKeyRef\": {\n",
    "                                    \"name\": \"mlpipeline-minio-artifact\",\n",
    "                                    \"key\": \"accesskey\",\n",
    "                                }\n",
    "                            },\n",
    "                        },\n",
    "                        {\n",
    "                            \"name\": \"AWS_SECRET_ACCESS_KEY\",\n",
    "                            \"valueFrom\": {\n",
    "                                \"secretKeyRef\": {\n",
    "                                    \"name\": \"mlpipeline-minio-artifact\",\n",
    "                                    \"key\": \"secretkey\",\n",
    "                                }\n",
    "                            },\n",
    "                        },\n",
    "                        {\n",
    "                            \"name\": \"AWS_REGION\",\n",
    "                            \"value\": \"minio\"\n",
    "                        },\n",
    "                        {\n",
    "                            \"name\": \"S3_ENDPOINT\",\n",
    "                            \"value\": f\"{minio_endpoint}\",\n",
    "                        },\n",
    "                        {\n",
    "                            \"name\": \"S3_USE_HTTPS\",\n",
    "                            \"value\": \"0\"\n",
    "                        },\n",
    "                        {\n",
    "                            \"name\": \"S3_VERIFY_SSL\",\n",
    "                            \"value\": \"0\"\n",
    "                        },\n",
    "                    ]\n",
    "                }]\n",
    "            }\n",
    "        }),\n",
    "    ).set_display_name(\"Visualization\")\n",
    "\n",
    "    prep_task = (\n",
    "        prep_op().after(prepare_tb_task\n",
    "                       ).set_display_name(\"Preprocess & Transform\")\n",
    "    )\n",
    "    confusion_matrix_url = f\"minio://{log_bucket}/{confusion_matrix_log_dir}\"\n",
    "    script_args = f\"model_name=bert.pth,\" \\\n",
    "                  f\"num_samples={num_samples},\" \\\n",
    "                  f\"confusion_matrix_url={confusion_matrix_url}\"\n",
    "    # For GPU , set gpus count and accelerator type\n",
    "    ptl_args = f\"max_epochs={max_epochs},profiler=pytorch,gpus=0,accelerator=None\"\n",
    "    train_task = (\n",
    "        train_op(\n",
    "            input_data=prep_task.outputs[\"output_data\"],\n",
    "            bert_script_args=script_args,\n",
    "            ptl_arguments=ptl_args\n",
    "        ).after(prep_task).set_display_name(\"Training\")\n",
    "    )\n",
    "    # For GPU uncomment below line and set GPU limit and node selector\n",
    "    # ).set_gpu_limit(1).add_node_selector_constraint\n",
    "    # ('cloud.google.com/gke-accelerator','nvidia-tesla-p4')\n",
    "\n",
    "    (\n",
    "        minio_op(\n",
    "            bucket_name=\"mlpipeline\",\n",
    "            folder_name=log_dir,\n",
    "            input_path=train_task.outputs[\"tensorboard_root\"],\n",
    "            filename=\"\",\n",
    "        ).after(train_task).set_display_name(\"Tensorboard Events Pusher\")\n",
    "    )\n",
    "    minio_mar_upload = (\n",
    "        minio_op(\n",
    "            bucket_name=\"mlpipeline\",\n",
    "            folder_name=mar_path,\n",
    "            input_path=train_task.outputs[\"checkpoint_dir\"],\n",
    "            filename=\"bert_test.mar\",\n",
    "        ).after(train_task).set_display_name(\"Mar Pusher\")\n",
    "    )\n",
    "    (\n",
    "        minio_op(\n",
    "            bucket_name=\"mlpipeline\",\n",
    "            folder_name=config_prop_path,\n",
    "            input_path=train_task.outputs[\"checkpoint_dir\"],\n",
    "            filename=\"config.properties\",\n",
    "        ).after(train_task).set_display_name(\"Conifg Pusher\")\n",
    "    )\n",
    "\n",
    "    model_uri = str(model_uri)\n",
    "    # pylint: disable=unused-variable\n",
    "    isvc_yaml = \"\"\"\n",
    "    apiVersion: \"serving.kubeflow.org/v1beta1\"\n",
    "    kind: \"InferenceService\"\n",
    "    metadata:\n",
    "      name: {}\n",
    "      namespace: {}\n",
    "    spec:\n",
    "      predictor:\n",
    "        serviceAccountName: sa\n",
    "        pytorch:\n",
    "          storageUri: {}\n",
    "          resources:\n",
    "            requests: \n",
    "              cpu: 4\n",
    "              memory: 8Gi\n",
    "            limits:\n",
    "              cpu: 4\n",
    "              memory: 8Gi\n",
    "    \"\"\".format(deploy, namespace, model_uri)\n",
    "\n",
    "    # For GPU inference use below yaml with gpu count and accelerator\n",
    "    gpu_count = \"1\"\n",
    "    accelerator = \"nvidia-tesla-p4\"\n",
    "    isvc_gpu_yaml = \"\"\"\n",
    "    apiVersion: \"serving.kubeflow.org/v1beta1\"\n",
    "    kind: \"InferenceService\"\n",
    "    metadata:\n",
    "      name: {}\n",
    "      namespace: {}\n",
    "    spec:\n",
    "      predictor:\n",
    "        serviceAccountName: sa\n",
    "        pytorch:\n",
    "          storageUri: {}\n",
    "          resources:\n",
    "            requests: \n",
    "              cpu: 4\n",
    "              memory: 8Gi\n",
    "            limits:\n",
    "              cpu: 4\n",
    "              memory: 8Gi\n",
    "              nvidia.com/gpu: {}\n",
    "          nodeSelector:\n",
    "            cloud.google.com/gke-accelerator: {}\n",
    "\"\"\".format(deploy, namespace, model_uri, gpu_count, accelerator)\n",
    "    # Update inferenceservice_yaml for GPU inference\n",
    "    deploy_task = (\n",
    "        deploy_op(action=\"apply\", inferenceservice_yaml=isvc_yaml\n",
    "                 ).after(minio_mar_upload).set_display_name(\"Deployer\")\n",
    "    )\n",
    "\n",
    "    dsl.get_pipeline_conf().add_op_transformer(\n",
    "        use_k8s_secret(\n",
    "            secret_name=\"mlpipeline-minio-artifact\",\n",
    "            k8s_secret_key_to_env={\n",
    "                \"secretkey\": \"MINIO_SECRET_KEY\",\n",
    "                \"accesskey\": \"MINIO_ACCESS_KEY\",\n",
    "            },\n",
    "        )\n",
    "    )"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "source": [
    "# Compile pipeline\n",
    "compiler.Compiler().compile(pytorch_bert, 'pytorch.tar.gz', type_check=True)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "source": [
    "# Execute pipeline\n",
    "run = client.run_pipeline(my_experiment.id, 'pytorch-bert', 'pytorch.tar.gz')"
   ],
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ],
      "text/html": [
       "<a href=\"http://istio-ingressgateway.istio-system.svc.cluster.local/pipeline/#/runs/details/308b618c-38d0-43d6-ae6d-3b0d11418f07\" target=\"_blank\" >Run details</a>."
      ]
     },
     "metadata": {}
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Wait for inference service below to go to `READY True` state."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "source": [
    "!kubectl get isvc $DEPLOY"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "NAME         URL                                                       READY   PREV   LATEST   PREVROLLEDOUTREVISION   LATESTREADYREVISION                  AGE\n",
      "bertserve    http://bertserve.kubeflow-user-example-com.example.com    True           100                              bertserve-predictor-default-zckhh    44h\n",
      "torchserve   http://torchserve.kubeflow-user-example-com.example.com   True           100                              torchserve-predictor-default-tdknk   47h\n"
     ]
    }
   ],
   "metadata": {
    "scrolled": true
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Get Inferenceservice name"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "source": [
    "INFERENCE_SERVICE_LIST = ! kubectl get isvc {DEPLOY_NAME} -n {NAMESPACE} -o json | python3 -c \"import sys, json; print(json.load(sys.stdin)['status']['url'])\"| tr -d '\"' | cut -d \"/\" -f 3\n",
    "INFERENCE_SERVICE_NAME = INFERENCE_SERVICE_LIST[0]\n",
    "INFERENCE_SERVICE_NAME"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "'bertserve.kubeflow-user-example-com.example.com'"
      ]
     },
     "metadata": {},
     "execution_count": 14
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Prediction Request"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "source": [
    "!curl -v -H \"Host: $INFERENCE_SERVICE_NAME\" -H \"Cookie: $COOKIE\" \"$INGRESS_GATEWAY/v1/models/$MODEL_NAME:predict\" -d @./bert/sample.txt > bert_prediction_output.json"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
      "                                 Dload  Upload   Total   Spent    Left  Speed\n",
      "  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0*   Trying 10.100.251.14:80...\n",
      "* TCP_NODELAY set\n",
      "* Connected to istio-ingressgateway.istio-system.svc.cluster.local (10.100.251.14) port 80 (#0)\n",
      "> POST /v1/models/bert:predict HTTP/1.1\n",
      "> Host: bertserve.kubeflow-user-example-com.example.com\n",
      "> User-Agent: curl/7.68.0\n",
      "> Accept: */*\n",
      "> Cookie: authservice_session=MTYyMzI1NDI0NHxOd3dBTkVGU1UxaElXRXN5VUVKTVJrZFVUMWhDU1VoVlNVMUhSRFZaVVRWQlNrVkhORTAzUTFWTFVqZExSa0pHVmpWU016SmFOa0U9fO86sBQIDoqYUxX9ffUnG7xS8xyysaWppWJa0c3QBRJd\n",
      "> Content-Length: 84\n",
      "> Content-Type: application/x-www-form-urlencoded\n",
      "> \n",
      "} [84 bytes data]\n",
      "* upload completely sent off: 84 out of 84 bytes\n",
      "* Mark bundle as not supporting multiuse\n",
      "< HTTP/1.1 200 OK\n",
      "< content-length: 33\n",
      "< content-type: application/json; charset=UTF-8\n",
      "< date: Thu, 10 Jun 2021 08:51:03 GMT\n",
      "< server: istio-envoy\n",
      "< x-envoy-upstream-service-time: 535\n",
      "< \n",
      "{ [33 bytes data]\n",
      "100   117  100    33  100    84     52    132 --:--:-- --:--:-- --:--:--   214\n",
      "* Connection #0 to host istio-ingressgateway.istio-system.svc.cluster.local left intact\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "source": [
    "! cat bert_prediction_output.json"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "{\"predictions\": [\"\\\"Sci/Tech\\\"\"]}"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Explanation Request"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "source": [
    "!curl -v -H \"Host: $INFERENCE_SERVICE_NAME\" -H \"Cookie: $COOKIE\" \"$INGRESS_GATEWAY/v1/models/$MODEL_NAME:explain\" -d @./bert/sample.txt  > bert_explaination_output.json"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
      "                                 Dload  Upload   Total   Spent    Left  Speed\n",
      "  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0*   Trying 10.100.251.14:80...\n",
      "* TCP_NODELAY set\n",
      "* Connected to istio-ingressgateway.istio-system.svc.cluster.local (10.100.251.14) port 80 (#0)\n",
      "> POST /v1/models/bert:explain HTTP/1.1\n",
      "> Host: bertserve.kubeflow-user-example-com.example.com\n",
      "> User-Agent: curl/7.68.0\n",
      "> Accept: */*\n",
      "> Cookie: authservice_session=MTYyMzI1NDI0NHxOd3dBTkVGU1UxaElXRXN5VUVKTVJrZFVUMWhDU1VoVlNVMUhSRFZaVVRWQlNrVkhORTAzUTFWTFVqZExSa0pHVmpWU016SmFOa0U9fO86sBQIDoqYUxX9ffUnG7xS8xyysaWppWJa0c3QBRJd\n",
      "> Content-Length: 84\n",
      "> Content-Type: application/x-www-form-urlencoded\n",
      "> \n",
      "} [84 bytes data]\n",
      "* upload completely sent off: 84 out of 84 bytes\n",
      "100    84    0     0  100    84      0      1  0:01:24  0:00:52  0:00:32     000:02     00:30  0:00:12     08  0:00:04     0  0   0      1  0:01:24  0:00:45  0:00:39     0* Mark bundle as not supporting multiuse\n",
      "< HTTP/1.1 200 OK\n",
      "< content-length: 320\n",
      "< content-type: application/json; charset=UTF-8\n",
      "< date: Thu, 10 Jun 2021 08:23:36 GMT\n",
      "< server: istio-envoy\n",
      "< x-envoy-upstream-service-time: 52441\n",
      "< \n",
      "{ [320 bytes data]\n",
      "100   404  100   320  100    84      6      1  0:01:24  0:00:52  0:00:32    77\n",
      "* Connection #0 to host istio-ingressgateway.istio-system.svc.cluster.local left intact\n"
     ]
    }
   ],
   "metadata": {
    "scrolled": true
   }
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "source": [
    "! cat bert_explaination_output.json"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "{\"explanations\": [{\"words\": [\"[CLS]\", \"bloomberg\", \"has\", \"reported\", \"on\", \"the\", \"economy\", \"[SEP]\"], \"importances\": [0.49803253502586686, -0.042289041470624116, -0.2269101439114476, 0.15573707990586028, 0.0867725310070807, 0.17919607383818434, 0.5255456841947312, -0.5988271940782108], \"delta\": 0.12081503337965546}]}"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "source": [
    "explanations_json = json.loads(open(\"./bert_explaination_output.json\", \"r\").read())\n",
    "explanations_json"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "{'explanations': [{'words': ['[CLS]',\n",
       "    'bloomberg',\n",
       "    'has',\n",
       "    'reported',\n",
       "    'on',\n",
       "    'the',\n",
       "    'economy',\n",
       "    '[SEP]'],\n",
       "   'importances': [0.49803253502586686,\n",
       "    -0.042289041470624116,\n",
       "    -0.2269101439114476,\n",
       "    0.15573707990586028,\n",
       "    0.0867725310070807,\n",
       "    0.17919607383818434,\n",
       "    0.5255456841947312,\n",
       "    -0.5988271940782108],\n",
       "   'delta': 0.12081503337965546}]}"
      ]
     },
     "metadata": {},
     "execution_count": 19
    }
   ],
   "metadata": {
    "scrolled": true,
    "tags": []
   }
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "source": [
    "prediction_json = json.loads(open(\"./bert_prediction_output.json\", \"r\").read())"
   ],
   "outputs": [],
   "metadata": {
    "tags": []
   }
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "source": [
    "import torch\n",
    "attributions = explanations_json[\"explanations\"][0]['importances']\n",
    "tokens = explanations_json[\"explanations\"][0]['words']\n",
    "delta = explanations_json[\"explanations\"][0]['delta']\n",
    "\n",
    "attributions = torch.tensor(attributions)\n",
    "pred_prob = 0.75\n",
    "pred_class = prediction_json[\"predictions\"][0]\n",
    "true_class = \"Business\"\n",
    "attr_class =\"world\""
   ],
   "outputs": [],
   "metadata": {
    "tags": []
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Visualization of Predictions"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "source": [
    "from captum.attr import visualization\n",
    "vis_data_records =[]\n",
    "vis_data_records.append(visualization.VisualizationDataRecord(\n",
    "                            attributions,\n",
    "                            pred_prob,\n",
    "                            pred_class,\n",
    "                            true_class,\n",
    "                            attr_class,\n",
    "                            attributions.sum(),       \n",
    "                            tokens,\n",
    "                            delta))"
   ],
   "outputs": [],
   "metadata": {
    "tags": []
   }
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "source": [
    "vis = visualization.visualize_text(vis_data_records)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### visualization appreas as below\n",
    "![viz1.png](./viz1.png)"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Cleanup Script"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "source": [
    "! kubectl delete --all isvc -n $NAMESPACE"
   ],
   "outputs": [],
   "metadata": {
    "scrolled": true
   }
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "source": [
    "! kubectl delete pod --field-selector=status.phase==Succeeded -n $NAMESPACE"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [],
   "outputs": [],
   "metadata": {}
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}